{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qa_store import QuestionAnswerKB\n",
    "\n",
    "\n",
    "def test_qa_store() -> None:\n",
    "  \"\"\"Exercise QuestionAnswerKB end to end, asserting on every step.\n",
    "\n",
    "  Starts by calling `reset_database()` on a new `QuestionAnswerKB`, then\n",
    "  checks, in order: adding and querying QA pairs, updating an answer,\n",
    "  attaching and filtering on metadata, and passing `num_rewordings` both\n",
    "  when adding a pair and when querying.\n",
    "\n",
    "  Prints \"All tests passed!\" once every assertion has held.\n",
    "\n",
    "  Raises:\n",
    "    AssertionError: on the first check that does not hold.\n",
    "  \"\"\"\n",
    "  store = QuestionAnswerKB()\n",
    "  store.reset_database()\n",
    "\n",
    "  # Seed three QA pairs; note the last answer is an int, not a str.\n",
    "  seed_pairs = [\n",
    "    (\"What is the capital of Germany?\", \"Berlin\"),\n",
    "    (\"Who is the president of the USA?\", \"Joe Biden\"),\n",
    "    (\n",
    "      \"What is the answer to life, the universe, and everything?\",\n",
    "      42,\n",
    "    ),\n",
    "  ]\n",
    "  for question, answer in seed_pairs:\n",
    "    store.add_qa(question, answer)\n",
    "\n",
    "  # Each seeded question must return its own answer as the top hit; the\n",
    "  # int answer 42 is expected back as the string \"42\".\n",
    "  expected_answers = [\"Berlin\", \"Joe Biden\", \"42\"]\n",
    "  for (question, _), expected in zip(seed_pairs, expected_answers):\n",
    "    hits = store.query(question)\n",
    "    assert hits[0][\"answer\"] == expected\n",
    "\n",
    "  # After update_answer, the same question must return the new answer.\n",
    "  germany_question = seed_pairs[0][0]\n",
    "  store.update_answer(germany_question, \"Munich\")\n",
    "  hits = store.query(germany_question)\n",
    "  assert hits[0][\"answer\"] == \"Munich\"\n",
    "\n",
    "  # Metadata given to add_qa must come back on the query result.\n",
    "  france_question = \"What is the capital of France?\"\n",
    "  store.add_qa(\n",
    "    france_question,\n",
    "    \"Paris\",\n",
    "    metadata={\"source\": \"Wikipedia\"},\n",
    "  )\n",
    "  hits = store.query(france_question)\n",
    "  assert hits[0][\"metadata\"][\"source\"] == \"Wikipedia\"\n",
    "\n",
    "  # A metadata_filter equal to the stored metadata still finds the pair.\n",
    "  hits = store.query(\n",
    "    france_question,\n",
    "    metadata_filter={\"source\": \"Wikipedia\"},\n",
    "  )\n",
    "  assert hits[0][\"metadata\"][\"source\"] == \"Wikipedia\"\n",
    "\n",
    "  # A metadata_filter that differs from the stored metadata finds nothing.\n",
    "  hits = store.query(\n",
    "    france_question,\n",
    "    metadata_filter={\"source\": \"github\"},\n",
    "  )\n",
    "  assert not hits\n",
    "\n",
    "  # num_rewordings passed when adding: a differently worded query\n",
    "  # (\"show\" instead of \"series\") must still find the pair.\n",
    "  store.add_qa(\n",
    "    \"What is your favorite TV series\",\n",
    "    \"Band of Brothers\",\n",
    "    num_rewordings=5,\n",
    "  )\n",
    "  hits = store.query(\"What is your favorite TV show\")\n",
    "  assert hits[0][\"answer\"] == \"Band of Brothers\"\n",
    "\n",
    "  # num_rewordings passed when querying: the pair is added as-is and the\n",
    "  # query is phrased differently from the stored question.\n",
    "  store.add_qa(\"What is your favorite book?\", \"The Great Gatsby\")\n",
    "  hits = store.query(\n",
    "    \"Do you have a book you like the most?\",\n",
    "    num_rewordings=5,\n",
    "  )\n",
    "  assert hits[0][\"answer\"] == \"The Great Gatsby\"\n",
    "\n",
    "  print(\"All tests passed!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/dothomps/src/witt3rd/ecaa/notebooks/.venv/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  from tqdm.autonotebook import tqdm, trange\n",
      "\u001b[32m2024-07-04 11:46:09.593\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mqa_store.main\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m47\u001b[0m - \u001b[1mQuestionAnswerKB initialized for collection 'qa_kb'.\u001b[0m\n",
      "Number of requested results 5 is greater than number of elements in index 3, updating n_results = 3\n",
      "Number of requested results 5 is greater than number of elements in index 3, updating n_results = 3\n",
      "Number of requested results 5 is greater than number of elements in index 3, updating n_results = 3\n",
      "\u001b[32m2024-07-04 11:46:11.230\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mqa_store.main\u001b[0m:\u001b[36mupdate_answer\u001b[0m:\u001b[36m220\u001b[0m - \u001b[1mAnswer to question 'What is the capital of Germany?' has been updated.\u001b[0m\n",
      "Number of requested results 5 is greater than number of elements in index 3, updating n_results = 3\n",
      "Number of requested results 5 is greater than number of elements in index 4, updating n_results = 4\n",
      "Number of requested results 5 is greater than number of elements in index 4, updating n_results = 4\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All tests passed!\n"
     ]
    }
   ],
   "source": [
    "# Run the QuestionAnswerKB checks defined above; prints a confirmation\n",
    "# line when every assertion holds.\n",
    "test_qa_store()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Advanced Usage Example\n",
    "\n",
    "This example showcases the use of question rewordings for both adding QA pairs and querying the knowledge base:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/dothomps/src/witt3rd/ecaa/notebooks/.venv/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  from tqdm.autonotebook import tqdm, trange\n",
      "\u001b[32m2024-07-04 16:50:51.927\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mqa_store.main\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m72\u001b[0m - \u001b[1mQuestionAnswerKB initialized for collection 'qa_kb'.\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Added questions:\n",
      "- - How should one go about learning programming effectively?\n",
      "- - What methods are most effective for mastering programming?\n",
      "- - Could you tell me the optimal approach to learning programming?\n",
      "- What is the best way to learn programming?\n",
      "\n",
      "Query results:\n",
      "Question: - How should one go about learning programming effectively?\n",
      "Answer: The best way to learn programming is through consistent practice, working on real projects, and continuous learning.\n",
      "Similarity: 0.39\n",
      "Metadata: {'field': 'computer science', 'topic': 'education'}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Initialize the Knowledge Base (QuestionAnswerKB is imported at the top\n",
    "# of the notebook).\n",
    "kb = QuestionAnswerKB()\n",
    "\n",
    "# Add a question-answer pair, asking for 3 rewordings of the question\n",
    "original_question = (\n",
    "  \"What is the best way to learn programming?\"\n",
    ")\n",
    "answer = \"The best way to learn programming is through consistent practice, working on real projects, and continuous learning.\"\n",
    "\n",
    "added_questions = kb.add_qa(\n",
    "  question=original_question,\n",
    "  answer=answer,\n",
    "  metadata={\n",
    "    \"topic\": \"education\",\n",
    "    \"field\": \"computer science\",\n",
    "  },\n",
    "  num_rewordings=3,\n",
    ")\n",
    "\n",
    "# The questions returned by add_qa can already start with \"- \"; strip\n",
    "# that prefix before adding our own bullet so the list is not rendered\n",
    "# as \"- - ...\". Questions without the prefix are printed unchanged.\n",
    "print(\"Added questions:\")\n",
    "for added_question in added_questions:\n",
    "  print(f\"- {str(added_question).removeprefix('- ')}\")\n",
    "\n",
    "# Now let's query the Knowledge Base with a different phrasing,\n",
    "# restricted to entries whose metadata has topic == \"education\"\n",
    "query_question = (\n",
    "  \"How can I become proficient in coding?\"\n",
    ")\n",
    "\n",
    "results = kb.query(\n",
    "  question=query_question,\n",
    "  n_results=2,\n",
    "  metadata_filter={\"topic\": \"education\"},\n",
    "  num_rewordings=2,\n",
    ")\n",
    "\n",
    "# Show every returned match with its stored question, answer,\n",
    "# similarity score (2 decimal places) and metadata\n",
    "print(\"\\nQuery results:\")\n",
    "for result in results:\n",
    "  print(f\"Question: {result['question']}\")\n",
    "  print(f\"Answer: {result['answer']}\")\n",
    "  print(f\"Similarity: {result['similarity']:.2f}\")\n",
    "  print(f\"Metadata: {result['metadata']}\")\n",
    "  print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
